package cn.com.yd.commons.grpc;

import cn.com.yd.commons.grpc.config.ChannelFactory;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.handler.codec.http2.Http2SecurityUtil;
import io.netty.handler.ssl.*;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;

import java.util.Iterator;
import javax.net.ssl.SSLException;

/**
 * gRPC client for {@code CommonService}. Every call borrows a {@link ManagedChannel} from a
 * commons-pool2 {@link GenericObjectPool} and returns it to the pool before the method returns.
 *
 * @author 李庆海
 */
public class GrpcClient {

    /** Channel pool; assigned once in the constructor, which fails rather than leaving it null. */
    private final GenericObjectPool<ManagedChannel> channelPool;

    /**
     * Creates a client that uses the given pool configuration.
     *
     * @param server     server host name or IP address
     * @param port       server port
     * @param ssl        {@code true} to connect over TLS, {@code false} for plaintext
     * @param poolConfig pool configuration, e.g. {@code new GenericObjectPoolConfig()};
     *                   {@code null} selects a default configuration
     * @throws IllegalStateException if the TLS context, the channel builder or the pool cannot be
     *                               created (the original exception is attached as the cause)
     */
    public GrpcClient(String server, int port, boolean ssl, GenericObjectPoolConfig poolConfig) {
        GenericObjectPoolConfig config = (poolConfig == null) ? new GenericObjectPoolConfig() : poolConfig;
        try {
            ManagedChannelBuilder<?> channelBuilder = createChannelBuilder(server, port, ssl);
            this.channelPool = new GenericObjectPool<ManagedChannel>(new ChannelFactory(channelBuilder), config);
        } catch (Exception e) {
            // Fail fast: a client without a pool would otherwise only fail later, on every call,
            // with a NullPointerException that hides the real cause.
            throw new IllegalStateException("Failed to initialize gRPC client for " + server + ":" + port, e);
        }
    }

    /**
     * Creates a client that uses a default pool configuration.
     *
     * @param server server host name or IP address
     * @param port   server port
     * @param ssl    {@code true} to connect over TLS, {@code false} for plaintext
     * @throws IllegalStateException if the client cannot be initialized
     */
    public GrpcClient(String server, int port, boolean ssl) {
        this(server, port, ssl, null);
    }

    /**
     * Not supported: always throws {@link UnsupportedOperationException}.
     * Note that this class never closes its channel pool.
     */
    public void shutdown() throws Exception {
        throw new UnsupportedOperationException();
    }

    /**
     * Performs a unary (request/response) call using a blocking stub.
     *
     * @param req the request
     * @return the server's response
     * @throws Exception if a channel cannot be borrowed or the call fails
     */
    public Response handle(Request req) throws Exception {
        return subdo(channel -> CommonServiceGrpc.newBlockingStub(channel).handle(req));
    }

    /**
     * Performs a server-streaming call using a blocking stub.
     *
     * <p>NOTE(review): the returned iterator keeps reading from a channel that has already been
     * handed back to the pool. Whether the stream survives if the pool destroys that channel
     * depends on the pool configuration and {@code ChannelFactory} — TODO confirm.
     *
     * @param req the request
     * @return an iterator over the streamed responses
     * @throws Exception if a channel cannot be borrowed or the call fails to start
     */
    public Iterator<Response> serverStreamingHandle(Request req) throws Exception {
        return subdo(channel -> CommonServiceGrpc.newBlockingStub(channel).serverStreamingHandle(req));
    }

    /**
     * Starts a client-streaming call using an asynchronous stub.
     *
     * <p>NOTE(review): the stream outlives the borrow — see {@link #serverStreamingHandle(Request)}.
     *
     * @param responseObserver receives the server's response
     * @return the observer used to send requests to the server
     * @throws Exception if a channel cannot be borrowed or the call fails to start
     */
    public StreamObserver<Request> clientStreamingHandle(StreamObserver<Response> responseObserver) throws Exception {
        return subdo(channel -> CommonServiceGrpc.newStub(channel).clientStreamingHandle(responseObserver));
    }

    /**
     * Starts a bidirectional-streaming call using an asynchronous stub.
     *
     * <p>NOTE(review): the stream outlives the borrow — see {@link #serverStreamingHandle(Request)}.
     *
     * @param responseObserver receives the server's responses
     * @return the observer used to send requests to the server
     * @throws Exception if a channel cannot be borrowed or the call fails to start
     */
    public StreamObserver<Request> bidirectionalStreamingHandle(StreamObserver<Response> responseObserver) throws Exception {
        return subdo(channel -> CommonServiceGrpc.newStub(channel).bidirectionalStreamingHandle(responseObserver));
    }

    /**
     * Borrows a channel, runs {@code how} with it, and always returns the channel to the pool.
     *
     * @param how the call to perform on the borrowed channel
     * @param <T> the expected result type; the cast is unchecked, so a mismatch surfaces as a
     *            {@link ClassCastException} at the call site
     * @return the value produced by {@code how}
     * @throws Exception if borrowing fails or {@code how} throws
     */
    @SuppressWarnings("unchecked")
    private <T> T subdo(Do how) throws Exception {
        ManagedChannel channel = this.channelPool.borrowObject();
        try {
            return (T) how.indo(channel);
        } finally {
            this.channelPool.returnObject(channel);
        }
    }

    /**
     * Creates the channel builder for the target server.
     *
     * @param server server host name or IP address
     * @param port   server port
     * @param ssl    {@code true} for a Netty TLS channel, {@code false} for plaintext
     * @return the configured (not yet built) channel builder
     * @throws SSLException if the TLS context cannot be built
     */
    private static ManagedChannelBuilder<?> createChannelBuilder(String server, int port, boolean ssl) throws SSLException {
        if (!ssl) {
            return ManagedChannelBuilder.forAddress(server, port).usePlaintext();
        }
        return NettyChannelBuilder.forAddress(server, port).sslContext(createClientSslContext());
    }

    /**
     * Builds the client TLS context, negotiating HTTP/2 (or HTTP/1.1) via ALPN.
     *
     * <p>SECURITY NOTE(review): {@link InsecureTrustManagerFactory} accepts any server certificate
     * without verification. Traffic is encrypted but the server is not authenticated, which
     * permits man-in-the-middle attacks. Replace with a real trust manager for production use.
     *
     * @return the client SSL context
     * @throws SSLException if the context cannot be built
     */
    private static SslContext createClientSslContext() throws SSLException {
        return SslContextBuilder.forClient()
                // Use native OpenSSL only when it supports ALPN; otherwise fall back to the JDK provider.
                .sslProvider(OpenSsl.isAlpnSupported() ? SslProvider.OPENSSL : SslProvider.JDK)
                .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE)
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .applicationProtocolConfig(new ApplicationProtocolConfig(
                        ApplicationProtocolConfig.Protocol.ALPN,
                        ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
                        ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
                        ApplicationProtocolNames.HTTP_2,
                        ApplicationProtocolNames.HTTP_1_1))
                .build();
    }
}

/**
 * Callback that performs one gRPC call on a channel supplied by the caller.
 */
@FunctionalInterface
interface Do {
    /**
     * Performs the call on the given channel.
     *
     * @param channel the channel to use; in this file it is borrowed from a pool and returned
     *                right after this method returns, so implementations presumably should not
     *                shut it down — TODO confirm against all implementers
     * @return the call result, e.g. a response, a response iterator or a request observer
     */
    Object indo(ManagedChannel channel);
}